{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Splitting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import myutils\n",
    "\n",
    "path2data = \"./data\"\n",
    "sub_folder_jpg = \"hmdb51_jpg\"\n",
    "path2ajpgs = os.path.join(path2data, sub_folder_jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_vids, all_labels, catgs = myutils.get_vids(path2ajpgs) \n",
    "len(all_vids), len(all_labels), len(catgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_vids[:1], all_labels[:3], catgs[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels_dict = {}\n",
    "ind = 0\n",
    "for uc in catgs:\n",
    "    labels_dict[uc] = ind\n",
    "    ind+=1\n",
    "labels_dict "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_classes =5\n",
    "unique_ids = [id_ for id_, label in zip(all_vids,all_labels) if labels_dict[label]<num_classes]\n",
    "unique_labels = [label for id_, label in zip(all_vids,all_labels) if labels_dict[label]<num_classes]\n",
    "len(unique_ids),len(unique_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import StratifiedShuffleSplit\n",
    "\n",
    "sss = StratifiedShuffleSplit(n_splits=2, test_size=0.1, random_state=0)\n",
    "train_indx, test_indx = next(sss.split(unique_ids, unique_labels))\n",
    "\n",
    "train_ids = [unique_ids[ind] for ind in train_indx]\n",
    "train_labels = [unique_labels[ind] for ind in train_indx]\n",
    "print(len(train_ids), len(train_labels)) \n",
    "\n",
    "test_ids = [unique_ids[ind] for ind in test_indx]\n",
    "test_labels = [unique_labels[ind] for ind in test_indx]\n",
    "print(len(test_ids), len(test_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ids[:5], train_labels[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_ids[:5], test_labels[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, DataLoader, Subset\n",
    "import glob\n",
    "from PIL import Image\n",
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "np.random.seed(2020)\n",
    "random.seed(2020)\n",
    "torch.manual_seed(2020)\n",
    "\n",
    "class VideoDataset(Dataset):\n",
    "    def __init__(self, ids, labels, transform):      \n",
    "        self.transform = transform\n",
    "        self.ids = ids\n",
    "        self.labels = labels\n",
    "    def __len__(self):\n",
    "        return len(self.ids)\n",
    "    def __getitem__(self, idx):\n",
    "        path2imgs=glob.glob(self.ids[idx]+\"/*.jpg\")\n",
    "        path2imgs = path2imgs[:timesteps]\n",
    "        label = labels_dict[self.labels[idx]]\n",
    "        frames = []\n",
    "        for p2i in path2imgs:\n",
    "            frame = Image.open(p2i)\n",
    "            frames.append(frame)\n",
    "        \n",
    "        seed = np.random.randint(1e9)        \n",
    "        frames_tr = []\n",
    "        for frame in frames:\n",
    "            random.seed(seed)\n",
    "            np.random.seed(seed)\n",
    "            frame = self.transform(frame)\n",
    "            frames_tr.append(frame)\n",
    "        if len(frames_tr)>0:\n",
    "            frames_tr = torch.stack(frames_tr)\n",
    "        return frames_tr, label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# choose one\n",
    "model_type = \"3dcnn\"\n",
    "model_type = \"rnn\"    \n",
    "\n",
    "timesteps =16\n",
    "if model_type == \"rnn\":\n",
    "    h, w =224, 224\n",
    "    mean = [0.485, 0.456, 0.406]\n",
    "    std = [0.229, 0.224, 0.225]\n",
    "else:\n",
    "    h, w = 112, 112\n",
    "    mean = [0.43216, 0.394666, 0.37645]\n",
    "    std = [0.22803, 0.22145, 0.216989]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.transforms as transforms\n",
    "\n",
    "train_transformer = transforms.Compose([\n",
    "            transforms.Resize((h,w)),\n",
    "            transforms.RandomHorizontalFlip(p=0.5),  \n",
    "            transforms.RandomAffine(degrees=0, translate=(0.1,0.1)),    \n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(mean, std),\n",
    "            ])     "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds = VideoDataset(ids= train_ids, labels= train_labels, transform= train_transformer)\n",
    "print(len(train_ds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imgs, label = train_ds[10]\n",
    "imgs.shape, label, torch.min(imgs), torch.max(imgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pylab as plt\n",
    "%matplotlib inline\n",
    "\n",
    "plt.figure(figsize=(10,10))\n",
    "for ii,img in enumerate(imgs[::4]):\n",
    "    plt.subplot(2,2,ii+1)\n",
    "    plt.imshow(myutils.denormalize(img, mean, std))\n",
    "    plt.title(label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pylab as plt\n",
    "%matplotlib inline\n",
    "\n",
    "plt.figure(figsize=(20,20))\n",
    "plt.subplots_adjust(wspace=0, hspace=0)\n",
    "for ii,img in enumerate(imgs[::2]):\n",
    "    plt.subplot(1,8,ii+1)\n",
    "    plt.imshow(myutils.denormalize(img, mean, std))\n",
    "    plt.axis(\"off\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_transformer = transforms.Compose([\n",
    "            transforms.Resize((h,w)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(mean, std),\n",
    "            ]) \n",
    "test_ds = VideoDataset(ids= test_ids, labels= test_labels, transform= test_transformer)\n",
    "print(len(test_ds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imgs, label = test_ds[5]\n",
    "imgs.shape, label, torch.min(imgs), torch.max(imgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,10))\n",
    "for ii,img in enumerate(imgs[::4]):\n",
    "    plt.subplot(2,2,ii+1)\n",
    "    plt.imshow(myutils.denormalize(img, mean, std))\n",
    "    plt.title(label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining Data Loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def collate_fn_r3d_18(batch):\n",
    "    imgs_batch, label_batch = list(zip(*batch))\n",
    "    imgs_batch = [imgs for imgs in imgs_batch if len(imgs)>0]\n",
    "    label_batch = [torch.tensor(l) for l, imgs in zip(label_batch, imgs_batch) if len(imgs)>0]\n",
    "    imgs_tensor = torch.stack(imgs_batch)\n",
    "    imgs_tensor = torch.transpose(imgs_tensor, 2, 1)\n",
    "    labels_tensor = torch.stack(label_batch)\n",
    "    return imgs_tensor,labels_tensor\n",
    "\n",
    "def collate_fn_rnn(batch):\n",
    "    imgs_batch, label_batch = list(zip(*batch))\n",
    "    imgs_batch = [imgs for imgs in imgs_batch if len(imgs)>0]\n",
    "    label_batch = [torch.tensor(l) for l, imgs in zip(label_batch, imgs_batch) if len(imgs)>0]\n",
    "    imgs_tensor = torch.stack(imgs_batch)\n",
    "    labels_tensor = torch.stack(label_batch)\n",
    "    return imgs_tensor,labels_tensor\n",
    "    \n",
    "\n",
    "batch_size = 1\n",
    "if model_type == \"rnn\":\n",
    "    train_dl = DataLoader(train_ds, batch_size= batch_size,\n",
    "                          shuffle=True, collate_fn= collate_fn_rnn)\n",
    "    test_dl = DataLoader(test_ds, batch_size= 2*batch_size,\n",
    "                         shuffle=False, collate_fn= collate_fn_rnn)  \n",
    "else:\n",
    "    train_dl = DataLoader(train_ds, batch_size= batch_size, \n",
    "                          shuffle=True, collate_fn= collate_fn_r3d_18)\n",
    "    test_dl = DataLoader(test_ds, batch_size= 2*batch_size, \n",
    "                         shuffle=False, collate_fn= collate_fn_r3d_18)          "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for xb,yb in train_dl:\n",
    "    print(xb.shape, yb.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for xb,yb in test_dl:\n",
    "    print(xb.shape, yb.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "class Resnt18Rnn(nn.Module):\n",
    "    def __init__(self, params_model):\n",
    "        super(Resnt18Rnn, self).__init__()\n",
    "        num_classes = params_model[\"num_classes\"]\n",
    "        dr_rate= params_model[\"dr_rate\"]\n",
    "        pretrained = params_model[\"pretrained\"]\n",
    "        rnn_hidden_size = params_model[\"rnn_hidden_size\"]\n",
    "        rnn_num_layers = params_model[\"rnn_num_layers\"]\n",
    "        \n",
    "        baseModel = models.resnet18(pretrained=pretrained)\n",
    "        num_features = baseModel.fc.in_features\n",
    "        baseModel.fc = Identity()\n",
    "        self.baseModel = baseModel\n",
    "        self.dropout= nn.Dropout(dr_rate)\n",
    "        self.rnn = nn.LSTM(num_features, rnn_hidden_size, rnn_num_layers)\n",
    "        self.fc1 = nn.Linear(rnn_hidden_size, num_classes)\n",
    "    def forward(self, x):\n",
    "        b_z, ts, c, h, w = x.shape\n",
    "        ii = 0\n",
    "        y = self.baseModel((x[:,ii]))\n",
    "        output, (hn, cn) = self.rnn(y.unsqueeze(1))\n",
    "        for ii in range(1, ts):\n",
    "            y = self.baseModel((x[:,ii]))\n",
    "            out, (hn, cn) = self.rnn(y.unsqueeze(1), (hn, cn))\n",
    "        out = self.dropout(out[:,-1])\n",
    "        out = self.fc1(out) \n",
    "        return out \n",
    "    \n",
    "class Identity(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Identity, self).__init__()\n",
    "    def forward(self, x):\n",
    "        return x    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import models\n",
    "from torch import nn\n",
    "\n",
    "if model_type == \"rnn\":\n",
    "    params_model={\n",
    "        \"num_classes\": num_classes,\n",
    "        \"dr_rate\": 0.1,\n",
    "        \"pretrained\" : True,\n",
    "        \"rnn_num_layers\": 1,\n",
    "        \"rnn_hidden_size\": 100,}\n",
    "    model = Resnt18Rnn(params_model)        \n",
    "else:\n",
    "    model = models.video.r3d_18(pretrained=True, progress=False)\n",
    "    num_features = model.fc.in_features\n",
    "    model.fc = nn.Linear(num_features, num_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    if model_type==\"rnn\":\n",
    "        x = torch.zeros(1, 16, 3, h, w)\n",
    "    else:\n",
    "        x = torch.zeros(1, 3, 16, h, w)\n",
    "    y= model(x)\n",
    "    print(y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path2weights = \"./models/weights.pt\"\n",
    "torch.save(model.state_dict(), path2weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import optim\n",
    "from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau\n",
    "\n",
    "loss_func = nn.CrossEntropyLoss(reduction=\"sum\")\n",
    "opt = optim.Adam(model.parameters(), lr=3e-5)\n",
    "lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=5,verbose=1)\n",
    "os.makedirs(\"./models\", exist_ok=True)\n",
    "\n",
    "params_train={\n",
    "    \"num_epochs\": 20,\n",
    "    \"optimizer\": opt,\n",
    "    \"loss_func\": loss_func,\n",
    "    \"train_dl\": train_dl,\n",
    "    \"val_dl\": test_dl,\n",
    "    \"sanity_check\": True,\n",
    "    \"lr_scheduler\": lr_scheduler,\n",
    "    \"path2weights\": \"./models/weights_\"+model_type+\".pt\",\n",
    "    }\n",
    "model,loss_hist,metric_hist = myutils.train_val(model,params_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "myutils.plot_loss(loss_hist, metric_hist)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
